{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing Thinc models (with shape inference)\n",
    "\n",
    "This notebook shows how to visualize your Thinc models, along with their input and output dimensions, using [Graphviz](https://www.graphviz.org/) and [`pydot`](https://github.com/pydot/pydot). If you install `pydot` from within the notebook, make sure to restart your kernel (or Google Colab VM – [see here](https://stackoverflow.com/questions/49853303/how-to-install-pydot-graphviz-on-google-colab) for details) after installation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install \"thinc>=8.0.0\" pydot graphviz svgwrite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by defining a model with a number of layers chained together using the `chain` combinator:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import chain, expand_window, Relu, Maxout, Linear, Softmax\n",
    "\n",
    "n_hidden = 32\n",
    "dropout = 0.2\n",
    "\n",
    "model = chain(\n",
    "    expand_window(3),\n",
    "    Relu(nO=n_hidden, dropout=dropout, normalize=True),\n",
    "    Maxout(nO=n_hidden * 4),\n",
    "    Linear(nO=n_hidden * 2),\n",
    "    Relu(nO=n_hidden, dropout=dropout, normalize=True),\n",
    "    Linear(nO=n_hidden),\n",
    "    Relu(nO=n_hidden, dropout=dropout),\n",
    "    Softmax(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the visualization we want to achieve for this model: the **name of the layer** or combination of layers and the **input and output dimensions**. Note that `>>` refers to a chaining of layers.\n",
    "\n",
    "![](https://user-images.githubusercontent.com/13643239/72763790-b8ab3900-3be5-11ea-8ea7-91e4e9fbb181.png)\n",
    "\n",
    "This means we need to add a node for each layer, an edge connecting each node to the previous one (except for the first), and labels like `\"name|(nO, nI)\"` – for instance, `\"maxout|(128, 32)\"`. Here's a simple function that takes a Thinc layer (i.e. a `Model` instance) and returns a label with the layer name and its dimensions, using `?` for any dimension that isn't set yet:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
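    "def get_label(layer):\n",
    "    layer_name = layer.name\n",
    "    # Fall back to \"?\" for dimensions that are not set yet\n",
    "    nO = layer.get_dim(\"nO\") if layer.has_dim(\"nO\") else \"?\"\n",
    "    nI = layer.get_dim(\"nI\") if layer.has_dim(\"nI\") else \"?\"\n",
    "    # Escape \">\" so the names of chained layers don't break the record label\n",
    "    return f\"{layer_name}|({nO}, {nI})\".replace(\">\", \"&gt;\")"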
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use `pydot` to create a visualization for a given model. You can customize the direction of the nodes by setting `\"rankdir\"` (e.g. `\"TB\"` for \"top to bottom\") and adjust the font and arrow styling. To make the visualization render nicely in a notebook, we can call into IPython's utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pydot\n",
    "from IPython.display import SVG, display\n",
    "\n",
    "def visualize_model(model):\n",
    "    dot = pydot.Dot()\n",
    "    # Lay the graph out from left to right\n",
    "    dot.set(\"rankdir\", \"LR\")\n",
    "    dot.set_node_defaults(shape=\"record\", fontname=\"arial\", fontsize=\"10\")\n",
    "    dot.set_edge_defaults(arrowsize=\"0.7\")\n",
    "    nodes = {}\n",
    "    for i, layer in enumerate(model.layers):\n",
    "        # Add one node per layer, keyed by the layer's ID\n",
    "        label = get_label(layer)\n",
    "        node = pydot.Node(layer.id, label=label)\n",
    "        dot.add_node(node)\n",
    "        nodes[layer.id] = node\n",
    "        if i == 0:\n",
    "            # The first layer has no previous layer to connect to\n",
    "            continue\n",
    "        from_node = nodes[model.layers[i - 1].id]\n",
    "        to_node = nodes[layer.id]\n",
    "        if not dot.get_edge(from_node, to_node):\n",
    "            dot.add_edge(pydot.Edge(from_node, to_node))\n",
    "    display(SVG(dot.create_svg()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling `visualize_model` on the model defined above will render the visualization. However, most dimensions will now show up as `(?, ?)`, instead of the *actual* dimensions as shown in the graph above. That's because Thinc allows **defining models with missing shapes** and is able to **infer the missing shapes from the data** when you call `model.initialize`. The model visualized here doesn't define all its shapes, so the labels are incomplete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To fill in the missing shapes, we can call `model.initialize` with a sample of the expected input `X` and output `Y`. Thinc uses the shapes of these arrays to infer the missing dimensions, so running `visualize_model` again now shows the complete shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "# Dummy batch of 5 examples: only the array shapes matter for shape inference\n",
    "X = numpy.zeros((5, 784), dtype=\"f\")\n",
    "Y = numpy.zeros((5, 10), dtype=\"f\")\n",
    "model.initialize(X=X, Y=Y)\n",
    "\n",
    "visualize_model(model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
